# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s

# In GraphDef, custom gradient functions are modeled using GradientDef, which
# links a function to its gradient function. In MLIR, the gradient function is
# instead attached to the function itself as its `tf.gradient` attribute.

# CHECK: func private @foo0(
# CHECK:   tf.gradient = @foo_grad

# Float constant 0.25 with an empty tensor_shape (no dimensions). It is the
# input of the "foo" node and the first input of the SymbolicGradient node.
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.25
      }
    }
  }
  experimental_debug_info {
  }
}
# Node whose op name "foo" matches the library function "foo" defined below;
# it takes "Const" as its only input.
node {
  name: "foo"
  op: "foo"
  input: "Const"
  # NOTE(review): presumably asks the importer to skip shape inference for this
  # call; the same attribute is repeated on both library functions -- confirm.
  attr {
    key: "_disable_call_shape_inference"
    value {
      b: true
    }
  }
  experimental_debug_info {
  }
}
# Shape op applied to the output of the "foo" node (float input, int32 result).
# Its output is the first input of "gradients/Fill".
node {
  name: "gradients/Shape"
  op: "Shape"
  input: "foo"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "out_type"
    value {
      type: DT_INT32
    }
  }
  experimental_debug_info {
  }
}
# Float constant 1 with an empty tensor_shape (no dimensions). It is the second
# input of "gradients/Fill".
node {
  name: "gradients/grad_ys_0"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1
      }
    }
  }
  experimental_debug_info {
  }
}
# Fill op fed by "gradients/Shape" (the shape of foo's output) and the constant
# 1 from "gradients/grad_ys_0". Its result is the second input of the
# SymbolicGradient node below.
node {
  name: "gradients/Fill"
  op: "Fill"
  input: "gradients/Shape"
  input: "gradients/grad_ys_0"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "index_type"
    value {
      type: DT_INT32
    }
  }
  experimental_debug_info {
  }
}
# SymbolicGradient node for the function "foo". Its inputs are "Const" (the
# value that was passed to foo) and "gradients/Fill".
node {
  name: "gradients/foo_grad/SymbolicGradient"
  op: "SymbolicGradient"
  input: "Const"
  input: "gradients/Fill"
  # Input types: two floats, one per input listed above.
  attr {
    key: "Tin"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
      }
    }
  }
  # Output types: a single float.
  attr {
    key: "Tout"
    value {
      list {
        type: DT_FLOAT
      }
    }
  }
  # The function being differentiated, referenced by name; it carries the same
  # _disable_call_shape_inference attribute as the "foo" node.
  attr {
    key: "f"
    value {
      func {
        name: "foo"
        attr {
          key: "_disable_call_shape_inference"
          value {
            b: true
          }
        }
      }
    }
  }
  experimental_debug_info {
  }
}
library {
  # Library function "foo": one float input (also named "foo") and one float
  # output "foo1". The body computes Exp(x) - Exp(Neg(x)) for input x.
  function {
    signature {
      name: "foo"
      input_arg {
        name: "foo"
        type: DT_FLOAT
      }
      output_arg {
        name: "foo1"
        type: DT_FLOAT
      }
    }
    # Exp(x)
    node_def {
      name: "Exp"
      op: "Exp"
      input: "foo"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      experimental_debug_info {
        original_node_names: "Exp"
      }
    }
    # Neg(x)
    node_def {
      name: "Neg"
      op: "Neg"
      input: "foo"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      experimental_debug_info {
        original_node_names: "Neg"
      }
    }
    # Exp(Neg(x)), reading output "y" of the Neg node.
    node_def {
      name: "Exp_1"
      op: "Exp"
      input: "Neg:y:0"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      experimental_debug_info {
        original_node_names: "Exp_1"
      }
    }
    # Exp(x) - Exp(Neg(x))
    node_def {
      name: "sub_0"
      op: "Sub"
      input: "Exp:y:0"
      input: "Exp_1:y:0"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      experimental_debug_info {
        original_node_names: "sub_0"
      }
    }
    # Bind the function output "foo1" to output "z" of the Sub node.
    ret {
      key: "foo1"
      value: "sub_0:z:0"
    }
    attr {
      key: "_disable_call_shape_inference"
      value {
        b: true
      }
    }
  }
  # Library function "foo_grad": two float inputs ("foo_grad", "foo_grad1") and
  # one float output "foo_grad2", which is the product of the two inputs. The
  # GradientDef entry in this library registers it as the gradient of "foo".
  function {
    signature {
      name: "foo_grad"
      input_arg {
        name: "foo_grad"
        type: DT_FLOAT
      }
      input_arg {
        name: "foo_grad1"
        type: DT_FLOAT
      }
      output_arg {
        name: "foo_grad2"
        type: DT_FLOAT
      }
    }
    # Multiply the two function inputs.
    node_def {
      name: "mul_0"
      op: "Mul"
      input: "foo_grad"
      input: "foo_grad1"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      experimental_debug_info {
        original_node_names: "mul_0"
      }
    }
    # Bind the function output "foo_grad2" to output "z" of the Mul node.
    ret {
      key: "foo_grad2"
      value: "mul_0:z:0"
    }
    attr {
      key: "_disable_call_shape_inference"
      value {
        b: true
      }
    }
  }
  # GradientDef entry linking function "foo" to its gradient function
  # "foo_grad". This is the link the test expects to reappear as the
  # `tf.gradient = @foo_grad` attribute on the imported function.
  gradient {
    function_name: "foo"
    gradient_func: "foo_grad"
  }
}
# Version information recorded for this graph (producer 29, min_consumer 12).
# NOTE(review): presumably the GraphDef producer / minimum-consumer versions
# used for compatibility checks on import -- confirm.
versions {
  producer: 29
  min_consumer: 12
}
